# pylint: disable=missing-module-docstring
# pylint: disable=no-name-in-module

import copy
from pathlib import Path
from typing import Callable, Optional

import hydra
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
import wandb
from omegaconf import DictConfig, OmegaConf
from torch import nn
from torch.utils.data import DataLoader, WeightedRandomSampler, random_split
from tqdm import tqdm

from diffusion_bandit import utils
from diffusion_bandit.dataset_generation import ShapeXDataset
from diffusion_bandit.diffusion import DiffusionProcess
from diffusion_bandit.neural_networks.shape_score_nets import get_score_network


@hydra.main(
    version_base=None, config_path="configs", config_name="train_shape_score_model"
)
def main(config: DictConfig) -> None:
    """Train a score network with a diffusion loss and log progress to wandb.

    The run loads a dataset for the configured modality, splits it 80/20 into
    train/test subsets, builds the score network (plus an optional EMA copy),
    trains for ``config.training.n_epochs`` epochs and periodically writes
    checkpoints to ``config.outputs_dir``.

    Args:
        config: Hydra configuration. Sections read here: ``wandb``, ``device``,
            ``modality``, ``data_dir``, ``dataset`` (shape modality only),
            ``training``, ``diffusion``, ``score_model``, ``ema``,
            ``optimizer``, ``scheduler``, ``outputs_dir`` and ``save``.

    Raises:
        ValueError: If the shape data contains NaN or infinite values.
        NotImplementedError: If ``config.modality`` is neither ``"shape"`` nor
            ``"protein"``.
    """
    print(OmegaConf.to_yaml(config))
    wandb.init(
        config=OmegaConf.to_container(config, resolve=True, throw_on_missing=False),
        project=config.wandb.project,
        tags=config.wandb.tags,
        anonymous=config.wandb.anonymous,
        mode=config.wandb.mode,
        dir=Path(config.wandb.dir).absolute(),
    )
    utils.seeding.seed_everything(config)

    # Set up device
    device = torch.device(config.device)
    print(device)

    # Load dataset
    # NOTE(security): weights_only=False unpickles arbitrary objects; only load
    # dataset files that come from a trusted source.
    if config.modality == "shape":
        dataset_path = Path(config.data_dir) / f"{config.dataset.name}"
        dataset_dict = torch.load(dataset_path, weights_only=False)
        d_ext = dataset_dict["dataset_config"]["dataset"]["d_ext"]

        # Ensure x_data is a torch tensor and validate it with real exceptions
        # (asserts would be stripped when Python runs with -O).
        x_data = torch.as_tensor(dataset_dict["x_data"])
        if torch.isnan(x_data).any():
            raise ValueError("Input data contains NaNs.")
        if torch.isinf(x_data).any():
            raise ValueError("Input data contains infinite values.")

        # For shape, we use uniform sampling (no weighting)
        dataset = ShapeXDataset(x_data)

    elif config.modality == "protein":
        dataset_path = Path(config.data_dir) / "std_protein_regression_dataset.pt"
        x_data, _ = torch.load(dataset_path, weights_only=False)
        d_ext = 1280
        dataset = ShapeXDataset(x_data)

    else:
        raise NotImplementedError(f"modality: {config.modality} not supported")

    # Split dataset into training (80%) and testing (remainder) sets
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

    # Create DataLoaders; pinned memory is only enabled when using CUDA
    pin_memory = device.type == "cuda"
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.training.batch_size,
        shuffle=True,
        pin_memory=pin_memory,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=config.training.batch_size,
        shuffle=False,  # Never shuffle the test set
        pin_memory=pin_memory,
    )

    # Initialize the diffusion process
    diffusion_process = DiffusionProcess(
        beta_min=config.diffusion.beta_min,
        beta_max=config.diffusion.beta_max,
    )

    # Set up score model
    score_model = get_score_network(
        d_ext=d_ext,
        diffusion_process=diffusion_process,
        **config.score_model,
    ).to(device)

    n_trainable = sum(p.numel() for p in score_model.parameters() if p.requires_grad)
    print(f"Number of parameters: {n_trainable}")

    # Set up EMA model if enabled; ema_model stays None otherwise so that it is
    # always bound.
    use_ema = config.ema.use_ema
    ema_model = None
    if use_ema:
        ema_model = copy.deepcopy(score_model)
        for param in ema_model.parameters():
            param.requires_grad = False  # EMA model parameters are not trainable

    # Set up optimizer and scheduler; every config key except "type" is passed
    # through as a keyword argument.
    optimizer = getattr(optim, config.optimizer.type)(
        score_model.parameters(),
        **{k: v for k, v in config.optimizer.items() if k != "type"},
    )
    scheduler = None
    if config.scheduler.type != "no":
        scheduler = getattr(optim.lr_scheduler, config.scheduler.type)(
            optimizer, **{k: v for k, v in config.scheduler.items() if k != "type"}
        )

    # Make sure the checkpoint directory exists before the first torch.save.
    outputs_dir = Path(config.outputs_dir)
    outputs_dir.mkdir(parents=True, exist_ok=True)

    # Training loop
    n_epochs = config.training.n_epochs
    for epoch in range(n_epochs):
        train_loss = train_epoch(
            score_model,
            train_loader,
            optimizer,
            diffusion_process.loss_fn,
            device,
            epoch,
            n_epochs,
            ema_model=ema_model,
            ema_decay=config.ema.decay if use_ema else None,
        )
        eval_loss = evaluate_epoch(
            score_model, test_loader, diffusion_process.loss_fn, device
        )

        if scheduler is not None:
            scheduler.step()

        # Log losses to wandb (learning rate is read after the scheduler step)
        log_dict = {
            "train_loss": train_loss,
            "eval_loss": eval_loss,
            "learning_rate": optimizer.param_groups[0]["lr"],
        }
        if use_ema:
            log_dict["ema_eval_loss"] = evaluate_epoch(
                ema_model, test_loader, diffusion_process.loss_fn, device
            )
        wandb.log(log_dict)

        print(log_dict)

        # Checkpoint every 5 epochs, and always after the last epoch so the
        # fully trained model is never lost.
        is_last_epoch = epoch == n_epochs - 1
        if epoch % 5 == 0 or is_last_epoch:
            model_save_path = outputs_dir / f"{epoch}{config.save.name}.pth"
            save_dict = {
                "score_model": score_model,
                "beta_min": config.diffusion.beta_min,
                "beta_max": config.diffusion.beta_max,
            }
            if use_ema:
                save_dict["ema_model"] = ema_model
            torch.save(save_dict, model_save_path)
            print(f"Model checkpoint saved to {model_save_path}")


def train_epoch(
    model: nn.Module,
    train_loader: DataLoader,
    optimizer: optim.Optimizer,
    loss_fn: Callable,
    device: torch.device,
    epoch: int,
    n_epochs: int,
    ema_model: Optional[nn.Module] = None,
    ema_decay: float = 0.999,
    verbose: bool = False,
) -> float:
    """
    Train the model for one epoch.

    Batches whose loss or gradients contain NaN/Inf are reported and skipped:
    no optimizer step and no EMA update is performed for them, and they do
    not contribute to the returned average.

    Args:
        model (nn.Module): The model to be trained.
        train_loader (DataLoader): Dataloader yielding training batches.
        optimizer (optim.Optimizer): Optimizer updating the model parameters.
        loss_fn (Callable): Called as ``loss_fn(model, x_batch)``; must return
            a scalar loss tensor.
        device (torch.device): Device the batches are moved to.
        epoch (int): Zero-based index of the current epoch (for messages).
        n_epochs (int): Total number of epochs (for messages).
        ema_model (Optional[nn.Module]): If given, updated via ``update_ema``
            after every optimizer step.
        ema_decay (float): Decay passed to ``update_ema``. Only used when
            ``ema_model`` is not None.
        verbose (bool): If True, print the loss every 10 batches.

    Returns:
        float: Sample-weighted average loss over the batches that were used,
        or NaN if no batch was used.
    """
    model.train()
    total_loss = 0.0
    num_items = 0

    # Wrap the loader itself (not the enumerate object) so tqdm can read the
    # loader's length and display a total.
    for batch_idx, x_batch in enumerate(tqdm(train_loader)):
        x_batch = x_batch.to(device)
        optimizer.zero_grad()

        loss = loss_fn(model, x_batch)
        if not torch.isfinite(loss):
            print(f"NaN or Inf in loss at Epoch {epoch}, Batch {batch_idx}")
            continue  # Skip this batch

        loss.backward()

        # Check gradients; collect the offenders first so that the `continue`
        # below applies to the batch loop and really skips the optimizer step.
        bad_grad_names = [
            name
            for name, param in model.named_parameters()
            if param.grad is not None and not torch.isfinite(param.grad).all()
        ]
        if bad_grad_names:
            for name in bad_grad_names:
                print(f"NaN or Inf in gradients at {name}")
            continue  # Skip this batch

        optimizer.step()

        # Update EMA model if enabled
        if ema_model is not None:
            update_ema(model, ema_model, ema_decay)

        batch_size = x_batch.shape[0]
        loss_value = loss.item()
        total_loss += loss_value * batch_size
        num_items += batch_size

        if verbose and batch_idx % 10 == 0:
            print(
                f"Epoch {epoch+1}/{n_epochs}, Batch {batch_idx}, Loss: {loss_value:.5f}"
            )

    if num_items == 0:
        # Every batch was skipped (or the loader was empty): no average exists.
        return float("nan")
    return total_loss / num_items


def evaluate_epoch(
    model: nn.Module,
    test_loader: DataLoader,
    loss_fn: Callable,
    device: torch.device,
    verbose: bool = False,
) -> float:
    """
    Evaluate the model on the test set.

    The model is switched to eval mode and gradients are disabled. Batches
    whose loss is NaN or Inf are reported and excluded from the average.

    Args:
        model (nn.Module): The model to be evaluated.
        test_loader (DataLoader): Dataloader for the test set.
        loss_fn (Callable): Called as ``loss_fn(model, x_batch)``; must return
            a scalar loss tensor.
        device (torch.device): Device to use for evaluation.
        verbose (bool, optional): If True, print the evaluation loss. Defaults to False.

    Returns:
        float: Sample-weighted average loss on the test set, or NaN when no
        batch produced a finite loss (including an empty loader).
    """
    model.eval()
    total_loss = 0.0
    num_items = 0

    with torch.no_grad():
        for idx, x_batch in enumerate(test_loader):
            x_batch = x_batch.to(device)
            loss = loss_fn(model, x_batch)
            if not torch.isfinite(loss):
                print(f"NaN or Inf in eval loss at Batch {idx}")
                continue
            batch_size = x_batch.shape[0]
            total_loss += loss.item() * batch_size
            num_items += batch_size

    # Avoid ZeroDivisionError when nothing could be evaluated.
    eval_loss = total_loss / num_items if num_items > 0 else float("nan")
    if verbose:
        print(f"Evaluation Loss: {eval_loss:.5f}")
    return eval_loss


def update_ema(model: nn.Module, ema_model: nn.Module, decay: float):
    """
    Update the EMA model's state in place from the current model.

    Floating-point (and complex) state entries are blended as
    ``ema = decay * ema + (1 - decay) * current``. All other entries, such as
    the integer ``num_batches_tracked`` buffer of BatchNorm layers, cannot be
    scaled in place by a float and are copied from the current model instead.

    Args:
        model (nn.Module): The current model with updated parameters.
        ema_model (nn.Module): The EMA model to be updated. Its state dict
            must contain every key of ``model``'s state dict.
        decay (float): The decay rate for EMA.
    """
    with torch.no_grad():
        ema_state = ema_model.state_dict()
        for key, current_value in model.state_dict().items():
            ema_value = ema_state[key]
            if ema_value.is_floating_point() or ema_value.is_complex():
                ema_value.mul_(decay).add_(current_value, alpha=1 - decay)
            else:
                # Integer/bool buffers: an in-place float multiply would raise
                # a dtype-cast error, so mirror the current value.
                ema_value.copy_(current_value)


# Script entry point. main() is called without arguments because it is wrapped
# by the @hydra.main decorator, hence the pylint suppression below.
if __name__ == "__main__":
    main()  # pylint: disable=no-value-for-parameter
